Program Listing for File lanenet_cluster.py
↰ Return to documentation for file (codes/lanekerbnetros/lanenet_model/lanenet_cluster.py)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 17-05-2019
# @Author : Zhou Hui
# @Original site : https://github.com/MaybeShewill-CV/lanenet-lane-detection
# @File : lanenet_cluster.py
"""
Clustering stage of LaneNet instance segmentation: groups per-pixel embeddings into lane and curb instances and draws a fitted curve for each instance.
"""
import numpy as np
import glog as log
import math
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift
from sklearn.cluster import DBSCAN
import time
import warnings
import cv2
try:
from cv2 import cv2
except ImportError:
pass
class LaneNetCluster(object):
    """Cluster LaneNet instance embeddings and draw one fitted curve per instance.

    Pixels selected from a binary segmentation map (label 1 for lanes, label 2
    for curbs) are grouped by Mean Shift over their instance-embedding vectors.
    At most the 8 most populated groups are kept; for each one a cubic
    polynomial is fitted through its pixel coordinates and points sampled along
    the curve are drawn as filled circles onto a caller-supplied image.
    """

    def __init__(self):
        # Per-instance palette. The drawing methods of this class currently use
        # one fixed colour per call, so none of them reads this list.
        self._color_map = [np.array([255, 0, 0]),
                           np.array([0, 255, 0]),
                           np.array([0, 0, 255]),
                           np.array([125, 125, 0]),
                           np.array([0, 125, 125]),
                           np.array([125, 0, 125]),
                           np.array([50, 100, 50]),
                           np.array([100, 50, 100])]

    @staticmethod
    def _cluster(prediction, bandwidth):
        """Run Mean Shift clustering on embedding vectors.

        :param prediction: embedding features, one row per pixel
        :param bandwidth: Mean Shift kernel bandwidth
        :return: ``(num_clusters, labels, cluster_centers)``, or ``(0, [], [])``
            when scikit-learn rejects the input with ``ValueError`` (e.g. when
            there are no samples at all)
        """
        # Recent scikit-learn releases (1.0+) make MeanShift's constructor
        # arguments keyword-only; passing bandwidth positionally raises TypeError.
        ms = MeanShift(bandwidth=bandwidth, bin_seeding=True)
        try:
            ms.fit(prediction)
        except ValueError as err:
            log.error(err)
            return 0, [], []
        cluster_centers = ms.cluster_centers_
        return cluster_centers.shape[0], ms.labels_, cluster_centers

    @staticmethod
    def _cluster_v2(prediction, eps=0.7, min_samples=200):
        """Alternative DBSCAN clustering (not used by the drawing methods below).

        :param prediction: embedding features, one row per pixel
        :param eps: DBSCAN neighbourhood radius
        :param min_samples: DBSCAN core-point threshold
        :return: ``(num_clusters, labels, core_samples)``. Noise points carry
            label -1 and are not counted. The third element is DBSCAN's
            ``components_`` (its core samples), not per-cluster centroids.
        """
        db = DBSCAN(eps=eps, min_samples=min_samples).fit(prediction)
        db_labels = db.labels_
        num_clusters = len([label for label in np.unique(db_labels) if label != -1])
        log.info('The number of clusters is: {:d}'.format(num_clusters))
        return num_clusters, db_labels, db.components_

    @staticmethod
    def _get_label_area(binary_seg_ret, instance_seg_ret, label):
        """Collect embeddings and ``(row, col)`` coordinates of pixels equal to *label*.

        :param binary_seg_ret: 2-D per-pixel class map
        :param instance_seg_ret: per-pixel embedding image indexed like
            ``binary_seg_ret``
        :param label: class value to select
        :return: ``(features, coordinates)`` as float32 / int64 arrays, in
            row-major pixel order
        """
        idx = np.where(binary_seg_ret == label)
        # Vectorized gather instead of a per-pixel Python loop.
        feats = np.asarray(instance_seg_ret[idx], dtype=np.float32)
        coords = np.stack(idx, axis=1).astype(np.int64)
        return feats, coords

    @staticmethod
    def _get_lane_area(binary_seg_ret, instance_seg_ret):
        """Return embeddings and ``(row, col)`` coordinates of lane pixels (label 1)."""
        return LaneNetCluster._get_label_area(binary_seg_ret, instance_seg_ret, 1)

    @staticmethod
    def _get_curb_area(binary_seg_ret, instance_seg_ret):
        """Return embeddings and ``(row, col)`` coordinates of curb pixels (label 2)."""
        return LaneNetCluster._get_label_area(binary_seg_ret, instance_seg_ret, 2)

    @staticmethod
    def _thresh_coord(coord):
        """Keep rows whose first coordinate lies strictly between 0 and twice its mean.

        Not called by the drawing code in this class.
        """
        pts_x = coord[:, 0]
        mean_x = np.mean(pts_x)
        idx = np.where(np.abs(pts_x - mean_x) < mean_x)
        return coord[idx[0]]

    @staticmethod
    def _lane_fit(lane_pts, order=3, sample_spacing=5):
        """Fit column 0 as a polynomial of column 1 and sample the fitted curve.

        In this class the input is ``(col, row)`` image points, so the curve
        gives the image column as a function of the image row.

        :param lane_pts: ``(N, 2)`` array-like of points
        :param order: polynomial degree (default cubic)
        :param sample_spacing: roughly one sample per this many units of
            column-1 extent; spans shorter than this yield no samples
        :return: ``zip`` of ``(column0, column1)`` pairs. These are sampled from
            the fit, or are the raw input points when numpy warns that the fit is
            ill-conditioned (e.g. too few points). The zip is empty for empty
            input or when the fit itself fails.
        """
        if not isinstance(lane_pts, np.ndarray):
            lane_pts = np.array(lane_pts, np.float32)
        if lane_pts.size == 0:
            return zip([], [])
        x = lane_pts[:, 1]
        y = lane_pts[:, 0]
        with warnings.catch_warnings():
            # Escalate warnings (notably numpy's RankWarning for poorly
            # conditioned fits) to exceptions so they trigger the fallback.
            warnings.filterwarnings('error')
            try:
                poly = np.poly1d(np.polyfit(x, y, order))
            except Warning:
                # Unreliable fit: draw the raw points instead.
                return zip(y, x)
            except (TypeError, ValueError, np.linalg.LinAlgError):
                # Best effort: skip this instance rather than abort the frame.
                return zip([], [])
        x_min = int(np.min(x))
        x_max = int(np.max(x))
        # Number of samples (not a step size): one per `sample_spacing` units.
        num_samples = int(math.floor((x_max - x_min) / sample_spacing))
        x_fit = list(np.linspace(x_min, x_max, num_samples))
        y_fit = poly(x_fit)
        return zip(y_fit, x_fit)

    @staticmethod
    def _select_largest_clusters(labels, num_clusters, max_clusters=8):
        """Return ids of at most *max_clusters* clusters, the most populated first.

        When there are no more than *max_clusters* clusters, all ids are
        returned in ascending order.
        """
        if num_clusters <= max_clusters:
            return np.arange(num_clusters)
        counts = np.bincount(np.asarray(labels, dtype=np.int64), minlength=num_clusters)
        return np.argsort(-counts)[:max_clusters]

    def _draw_clusters(self, embedding_feats, coordinates, source_image, color):
        """Cluster the pixels, fit each kept cluster, and draw it on *source_image*.

        :param embedding_feats: one embedding row per selected pixel
        :param coordinates: matching ``(row, col)`` pixel coordinates
        :param source_image: image drawn on in place
        :param color: circle colour in the image's channel order
        :return: *source_image*
        """
        num_clusters, labels, _ = self._cluster(embedding_feats, bandwidth=1.5)
        for cluster_id in self._select_largest_clusters(labels, num_clusters):
            coord = coordinates[labels == cluster_id]
            # (row, col) -> (col, row), i.e. (x, y) in image terms.
            coord = np.flip(coord, axis=1)
            fitted = np.array(list(self._lane_fit(coord))).astype(int)
            for px, py in fitted:
                cv2.circle(source_image, (int(px), int(py)), 3, color, -1)
        return source_image

    def get_lane_mask(self, binary_seg_ret, instance_seg_ret, source_image, color=(0, 255, 0)):
        """Draw a fitted curve for each lane instance (label 1) onto *source_image*.

        :param binary_seg_ret: 2-D per-pixel class map
        :param instance_seg_ret: per-pixel embedding image
        :param source_image: image drawn on in place and returned
        :param color: circle colour in the image's channel order
        :return: *source_image*
        """
        feats, coords = self._get_lane_area(binary_seg_ret, instance_seg_ret)
        return self._draw_clusters(feats, coords, source_image, color)

    def get_curb_mask(self, binary_seg_ret, instance_seg_ret, source_image, color=(0, 0, 255)):
        """Draw a fitted curve for each curb instance (label 2) onto *source_image*.

        :param binary_seg_ret: 2-D per-pixel class map
        :param instance_seg_ret: per-pixel embedding image
        :param source_image: image drawn on in place and returned
        :param color: circle colour in the image's channel order
        :return: *source_image*
        """
        feats, coords = self._get_curb_area(binary_seg_ret, instance_seg_ret)
        return self._draw_clusters(feats, coords, source_image, color)
if __name__ == '__main__':
    # Demo: cluster a saved binary/instance segmentation pair and display the
    # fitted lanes next to the (normalized) embedding image.
    binary_seg_image = cv2.imread('binary_ret.png', cv2.IMREAD_GRAYSCALE)
    instance_seg_image = cv2.imread('instance_ret.png', cv2.IMREAD_UNCHANGED)
    if binary_seg_image is None or instance_seg_image is None:
        # cv2.imread returns None instead of raising when a file is unreadable.
        raise SystemExit('Could not read binary_ret.png and/or instance_ret.png')
    # The saved mask uses 255 for lane pixels; the clustering code selects label 1.
    binary_seg_image[binary_seg_image == 255] = 1
    # Stretch each embedding channel to 0..255 in floating point. The previous
    # in-place `*= int(scale)` truncated the factor (e.g. 1.9 -> 1) and could
    # overflow a uint8 image. All-zero channels keep a scale of 1.
    ele_max = np.max(instance_seg_image, axis=(0, 1)).astype(np.float32)
    scale = 255.0 / np.where(ele_max == 0, 255.0, ele_max)
    instance_seg_image = instance_seg_image.astype(np.float32) * scale
    embedding_image = np.clip(instance_seg_image, 0, 255).astype(np.uint8)
    cluster = LaneNetCluster()
    # get_lane_mask draws onto (and returns) the image it is given, so supply a
    # blank canvas matching the mask size.
    canvas = np.zeros((binary_seg_image.shape[0], binary_seg_image.shape[1], 3), dtype=np.uint8)
    mask_image = cluster.get_lane_mask(binary_seg_ret=binary_seg_image,
                                       instance_seg_ret=instance_seg_image,
                                       source_image=canvas)
    plt.figure('embedding')
    plt.imshow(embedding_image[:, :, (2, 1, 0)])
    plt.figure('mask_image')
    plt.imshow(mask_image[:, :, (2, 1, 0)])
    plt.show()